from ctypes import c_int16
from io import SEEK_SET
import io
from io import BytesIO
import math
from s4studio.core import Serializable, PackedResource, ExternalResource, ResourceKey
from s4studio.helpers import Flag
from s4studio.io import StreamReader, StreamWriter, TGIList, StreamPtr, RCOL
from s4studio.model.geometry import Mesh, Vertex, SkinController
from s4studio.model.material import MaterialBlock


class Blend:
    """Plain data holder for one blend entry.

    Instances are created and filled in by ``BlendGeometry.read``: two
    integer header fields plus one ``Blend.LOD`` per level of detail.
    """

    class Vertex:
        """A single blended vertex: an id plus optional position/normal data."""

        __slots__ = {
            'id',
            'normal',
            'position'
        }

        def __init__(self):
            # Both vectors stay None until a reader supplies them.
            self.position = self.normal = None
            self.id = 0

    class LOD:
        """Holds the list of ``Blend.Vertex`` objects for one level of detail."""

        def __init__(self):
            self.vertices = list()

    def __init__(self):
        self.lods = list()
        self.blend_region = 0
        self.age_gender_flags = 0


class SlotPose(RCOL):
    """RCOL chunk holding a version number and a list of per-bone deltas.

    On disk (as produced by ``write_rcol``): u32 version, i32 item count,
    then for every item a u32 bone hash, 3 floats of offset, 3 floats of
    scale and 4 floats of rotation.
    """

    class Item(object):
        """One delta entry: a bone identifier plus offset, scale and rotation."""

        def __init__(self, name=None, offset=None, scale=None, rotation=None):
            """Create an item.

            :param name: bone identifier; stored as ``bone_name`` and passed
                to ``StreamWriter.hash`` when written.
            :param offset: 3 floats, defaults to ``[0.0, 0.0, 0.0]``.
            :param scale: 3 floats, defaults to ``[1.0, 1.0, 1.0]``.
            :param rotation: 4 floats, defaults to ``[0.0, 0.0, 0.0, 1.0]``.
            """
            # The original discarded ``name`` and always stored None here.
            self.bone_name = name
            # Fresh lists per instance so defaults are never shared.
            self.offset = [0.0] * 3 if offset is None else offset
            self.scale = [1.0] * 3 if scale is None else scale
            self.rotation = [0.0, 0.0, 0.0, 1.0] if rotation is None else rotation

    ID = 0x0355E0A6

    def __init__(self, key=None, stream=None):
        # Attributes are set before RCOL.__init__ is handed the stream.
        self.version = 0x00000000
        self.deltas = []
        RCOL.__init__(self, key, stream)

    def read_rcol(self, stream, rcol):
        """Parse the chunk body from *stream*, replacing ``self.deltas``."""
        s = StreamReader(stream)
        self.version = s.u32()
        # Mirror write_rcol: an item count precedes the items. The previous
        # implementation skipped the count and decoded exactly one item.
        count = s.i32()
        deltas = []
        for _ in range(count):
            bone = s.u32()
            offset = [s.f32() for _ in range(3)]
            scale = [s.f32() for _ in range(3)]
            rotation = [s.f32() for _ in range(4)]
            deltas.append(self.Item(bone, offset, scale, rotation))
        self.deltas = deltas

    def write_rcol(self, stream, rcol):
        """Serialize the version, the item count and every delta to *stream*."""
        s = StreamWriter(stream)
        s.u32(self.version)
        s.i32(len(self.deltas))
        for delta in self.deltas:
            s.hash(delta.bone_name)
            for i in range(3):
                s.f32(delta.offset[i])
            for i in range(3):
                s.f32(delta.scale[i])
            for i in range(4):
                s.f32(delta.rotation[i])


class BodyGeometry(RCOL, Mesh):
    """RCOL chunk tagged ``'GEOM'``: vertices, triangles and bone references.

    ``read_rcol`` / ``write_rcol`` handle two layouts: version ``0x05``
    stores a skin-controller reference after the index buffer, any other
    version stores two lists of not-yet-understood records
    (``UnknownS4A`` / ``UnknownS4B``) instead.
    """
    TAG = 'GEOM'
    ID = 0x015A1849
    BLOCK_ID = 0x00000000

    class VertexFormat(Serializable):
        """Ordered list of ``Declaration`` objects describing one vertex."""

        def __init__(self, stream=None):
            self.declarations = []
            Serializable.__init__(self, stream)

        def read(self, stream, resource=None):
            """Read an i32 count followed by that many declarations."""
            s = StreamReader(stream)
            # range() consumes the count before the first Declaration is parsed.
            self.declarations = [self.Declaration(stream) for i in range(s.i32())]

        def write(self, stream, resource=None):
            """Write the declaration count followed by every declaration."""
            s = StreamWriter(stream)
            s.i32(len(self.declarations))
            for declaration in self.declarations:
                declaration.write(stream)

        def from_elements(self, position=1, normal=1, uv=1, blend_indices=1, blend_weights=1, tangents=1, color=1,
                          id=1):
            """Rebuild the declarations from per-element switches.

            Every argument enables its element when greater than zero; ``uv``
            additionally gives the number of UV sets. A placeholder vertex is
            populated accordingly and passed to ``from_vertex``.
            """
            template = Vertex()
            if position > 0:
                template.position = [0] * 3
            if normal > 0:
                template.normal = [0] * 3
            if uv > 0:
                template.uv = [[0] * 2 for i in range(uv)]
            if blend_indices > 0:
                template.blend_indices = [1] * 4
            if blend_weights > 0:
                template.blend_weights = [0, 0, 255, 0]
            if tangents > 0:
                template.tangent = [0, 0, 0]
            if color > 0:
                template.colour = [0] * 4
            if id > 0:
                template.id = 1
            self.from_vertex(template)

        def from_vertex(self, vertex):
            """Rebuild the declarations from the truthy attributes of *vertex*.

            Declaration order is fixed: position, normal, one entry per UV
            set, colour, blend indices, blend weights, tangent, id.
            """
            self.declarations = []
            assert isinstance(vertex, Vertex)
            usage = self.Declaration.USAGE
            fmt = self.Declaration.FORMAT

            def declare(usage_id, format_id, size):
                self.declarations.append(self.Declaration(None, usage_id, format_id, size))

            if vertex.position:
                declare(usage.POSITION, fmt.FLOAT, 12)
            if vertex.normal:
                declare(usage.NORMAL, fmt.FLOAT, 12)
            if vertex.uv:
                for _ in vertex.uv:
                    declare(usage.UV, fmt.FLOAT, 8)
            if vertex.colour:
                declare(usage.COLOUR, fmt.ARGB, 4)
            if vertex.blend_indices:
                declare(usage.BLEND_INDEX, fmt.BYTE, 4)
            if vertex.blend_weights:
                declare(usage.BLEND_WEIGHT, fmt.BYTE, 4)
            if vertex.tangent:
                declare(usage.TANGENT, fmt.FLOAT, 12)
            if vertex.id:
                declare(usage.ID, fmt.UINT, 4)

        class Declaration(Serializable):
            """One vertex element: what it is, how it is encoded, its byte size."""

            class USAGE:
                POSITION = 0x00000001
                NORMAL = 0x00000002
                UV = 0x00000003
                BLEND_INDEX = 0x00000004
                BLEND_WEIGHT = 0x00000005
                TANGENT = 0x00000006
                COLOUR = 0x00000007
                ID = 0x0000000A

            class FORMAT:
                FLOAT = 0x00000001
                BYTE = 0x00000002
                ARGB = 0x00000003
                UINT = 0x00000004

            def __init__(self, stream=None, usage=0, format=0, size=0):
                self.usage = usage
                self.format = format
                self.size = size
                Serializable.__init__(self, stream)

            def read(self, stream, resource=None):
                """Read usage (u32), format (u32) and size (u8)."""
                s = StreamReader(stream)
                self.usage = s.u32()
                self.format = s.u32()
                self.size = s.u8()

            def write(self, stream, resource=None):
                """Write usage (u32), format (u32) and size (u8)."""
                s = StreamWriter(stream)
                s.u32(self.usage)
                s.u32(self.format)
                s.u8(self.size)

    class UnknownS4A(Serializable):
        """Record of unknown meaning: a u32 followed by a list of float pairs."""

        def __init__(self, stream=None, resources=None):
            self.unknown1 = 0
            self.unknown2 = []
            super().__init__(stream, resources)

        def read(self, stream, resources=None):
            s = StreamReader(stream)
            self.unknown1 = s.u32()
            pair_count = s.i32()
            self.unknown2 = [[s.f32(), s.f32()] for _ in range(pair_count)]

        def write(self, stream, resources=None):
            s = StreamWriter(stream)
            s.u32(self.unknown1)
            s.i32(len(self.unknown2))
            for pair in self.unknown2:
                for value in pair:
                    s.f32(value)

    class UnknownS4B(Serializable):
        """Record of unknown meaning: u32, three u16, thirteen floats, one u8."""

        def __init__(self, stream=None, resources=None):
            self.unknown1 = 0
            self.unknown2 = [0] * 3
            self.unknown3 = [0.0] * 13
            self.unknown4 = 0xFF
            super().__init__(stream, resources)

        def read(self, stream, resources=None):
            s = StreamReader(stream)
            self.unknown1 = s.u32()
            self.unknown2 = [s.u16() for _ in range(3)]
            self.unknown3 = [s.f32() for _ in range(13)]
            self.unknown4 = s.u8()

        def write(self, stream, resources=None):
            s = StreamWriter(stream)
            s.u32(self.unknown1)
            for value in self.unknown2:
                s.u16(value)
            for value in self.unknown3:
                s.f32(value)
            s.u8(self.unknown4)

    # Vertex attribute holding the data of each single-valued usage. UV is
    # handled separately because a vertex may carry several UV sets.
    _USAGE_ATTRIBUTES = {
        VertexFormat.Declaration.USAGE.POSITION: 'position',
        VertexFormat.Declaration.USAGE.NORMAL: 'normal',
        VertexFormat.Declaration.USAGE.BLEND_INDEX: 'blend_indices',
        VertexFormat.Declaration.USAGE.BLEND_WEIGHT: 'blend_weights',
        VertexFormat.Declaration.USAGE.TANGENT: 'tangent',
        VertexFormat.Declaration.USAGE.COLOUR: 'colour',
        VertexFormat.Declaration.USAGE.ID: 'id',
    }

    def __init__(self, key=None):
        RCOL.__init__(self, key)
        self.version = 0
        self.indices = []
        self.vertices = []
        self.shader = None
        self.material = []
        self.merge_group = 0
        self.skin_controller = ExternalResource(ResourceKey(BodySkinController.ID))
        self.sort_order = 0
        self.vertex_format = self.VertexFormat()
        self.bones = []
        self.unknown1 = []
        self.unknown2 = []
        self.tgi = []

    @staticmethod
    def _element_size(declaration):
        """Return the byte width of one element of *declaration*'s format.

        :raises ValueError: for a format that is none of FLOAT, UINT, ARGB or
            BYTE (previously this surfaced as a ZeroDivisionError).
        """
        fmt = declaration.format
        if fmt in (declaration.FORMAT.FLOAT, declaration.FORMAT.UINT):
            return 4
        if fmt in (declaration.FORMAT.ARGB, declaration.FORMAT.BYTE):
            return 1
        raise ValueError("Unknown format")

    @classmethod
    def _read_vertex_data(cls, declaration, s):
        """Read ``declaration.size`` bytes from *s* as a list of elements."""
        element_count = declaration.size // cls._element_size(declaration)
        fmt = declaration.format
        if fmt == declaration.FORMAT.FLOAT:
            fetch = s.f32
        elif fmt == declaration.FORMAT.UINT:
            fetch = s.u32
        else:
            # ARGB and BYTE are both stored as individual unsigned bytes.
            fetch = s.u8
        return [fetch() for _ in range(element_count)]

    @classmethod
    def _write_vertex_data(cls, declaration, s, data):
        """Write the elements of *data* to *s* using *declaration*'s format."""
        if isinstance(data, (int, float)):
            # Accept a bare scalar (e.g. an integer vertex id) as one element.
            data = [data]
        assert len(data) == declaration.size // cls._element_size(declaration)
        fmt = declaration.format
        if fmt == declaration.FORMAT.FLOAT:
            emit = s.f32
        elif fmt == declaration.FORMAT.UINT:
            emit = s.u32
        else:
            emit = s.u8
        for element in data:
            emit(element)

    def min_vertex_id(self):
        """Return the smallest vertex id, or 0 when no ids are available.

        0 is returned when there are no vertices, when the vertex format has
        no ID declaration, or when every vertex id is None. Ids are compared
        as stored; ``read_rcol`` stores each id as a one-element list.
        """
        id_usage = BodyGeometry.VertexFormat.Declaration.USAGE.ID
        declares_id = any(d.usage == id_usage for d in self.vertex_format.declarations)
        if not any(self.vertices) or not declares_id:
            return 0
        return min((vertex.id for vertex in self.vertices if vertex.id is not None), default=0)

    def read_rcol(self, stream, rcol):
        """Parse the chunk from *stream*, replacing the current mesh data."""
        s = StreamReader(stream)
        self.read_tag(stream)
        self.version = s.u32()
        tgi = TGIList()
        tgi.begin_read(stream)
        self.tgi = tgi.blocks
        self.shader = s.u32()
        if self.shader:
            # The material block is kept as an opaque length-prefixed blob
            # rather than being parsed into a MaterialBlock.
            self.material = stream.read(s.u32())

        self.merge_group = s.u32()
        self.sort_order = s.u32()
        vertex_count = s.u32()
        self.vertex_format.read(stream)

        uv_usage = self.VertexFormat.Declaration.USAGE.UV
        attributes = self._USAGE_ATTRIBUTES
        # Start from an empty list so reading twice does not duplicate vertices.
        self.vertices = []
        for vertex_index in range(vertex_count):
            vertex = Vertex()
            for declaration in self.vertex_format.declarations:
                # Always consume the bytes, even for a usage we cannot store.
                element = self._read_vertex_data(declaration, s)
                usage = declaration.usage
                if usage == uv_usage:
                    if vertex.uv is None:
                        vertex.uv = []
                    vertex.uv.append(element)
                elif usage in attributes:
                    setattr(vertex, attributes[usage], element)
                else:
                    print('???')
            self.vertices.append(vertex)

        s.u32()  # item count; write_rcol always emits 1 and the value is unused
        bytes_per_index = s.u8()
        assert bytes_per_index == 2
        triangle_count = s.u32() // 3
        self.indices = [[s.u16() for _ in range(3)] for _ in range(triangle_count)]
        if self.version == 0x05:
            self.skin_controller = tgi.get_resource(s.u32())
        else:
            self.unknown1 = [self.UnknownS4A(stream) for _ in range(s.i32())]
            self.unknown2 = [self.UnknownS4B(stream) for _ in range(s.i32())]

        self.bones = [s.u32() for _ in range(s.u32())]
        tgi.end_read(stream)

    def write_rcol(self, stream, rcol):
        """Serialize the chunk to *stream*.

        The vertex format is regenerated from the first vertex; a mesh with
        no vertices keeps its current format instead of raising IndexError.
        """
        s = StreamWriter(stream)
        self.write_tag(stream)
        s.u32(self.version)
        tgi = TGIList()
        tgi.begin_write(stream)

        # Register the known keys first so they are present in the TGI list.
        for t in self.tgi:
            tgi.get_resource_index(t)
        if self.shader:
            s.hash(self.shader)
            # ``material`` defaults to an empty list; write that as zero bytes.
            material = self.material or b''
            s.u32(len(material))
            stream.write(material)
        else:
            s.u32(0)

        s.u32(self.merge_group)
        s.u32(self.sort_order)
        s.u32(len(self.vertices))
        if self.vertices:
            self.vertex_format.from_vertex(self.vertices[0])
        self.vertex_format.write(stream)

        uv_usage = self.VertexFormat.Declaration.USAGE.UV
        attributes = self._USAGE_ATTRIBUTES
        for vertex_index, vertex in enumerate(self.vertices):
            uv_index = 0
            for declaration in self.vertex_format.declarations:
                usage = declaration.usage
                if usage == uv_usage:
                    try:
                        self._write_vertex_data(declaration, s, vertex.uv[uv_index])
                        uv_index += 1
                    except Exception:
                        print('[%s] uv_index: %s uv:%s vertex:%s' % (vertex_index, uv_index, vertex.uv, vertex))
                        raise
                elif usage in attributes:
                    self._write_vertex_data(declaration, s, getattr(vertex, attributes[usage]))
        s.u32(1)
        s.u8(2)  # bytes per index; read_rcol reads this as u8 and requires 2
        s.u32(len(self.indices) * 3)
        for polygon in self.indices:
            for index in polygon:
                s.u16(index)

        if self.version == 0x00000005:
            s.u32(tgi.get_resource_index(self.skin_controller))
        else:
            s.i32(len(self.unknown1))
            for u in self.unknown1:
                u.write(stream)
            s.i32(len(self.unknown2))
            for u in self.unknown2:
                u.write(stream)
        s.u32(len(self.bones))
        for bone in self.bones:
            s.u32(bone)
        tgi.end_write(stream, True)

    def get_vertices(self):
        """Return the list of vertices."""
        return self.vertices

    def get_triangles(self):
        """Return the list of triangles, each a list of three vertex indices."""
        return self.indices

    def __str__(self):
        return "%s : Vertices:(%i) Faces:(%i)" % (PackedResource.__str__(self), len(self.vertices), len(self.indices))

class BlendGeometrySims4(RCOL):
    """RCOL wrapper whose body is parsed as a ``BlendGeometry``."""

    def __init__(self, key=None, stream=None, resources=None):
        # Defaults are in place before RCOL.__init__ receives the stream.
        self.bgeo = None
        self.version = 0
        RCOL.__init__(self, key, stream, resources)

    def read_rcol(self, stream, rcol):
        """Create a fresh ``BlendGeometry`` in ``self.bgeo`` and read it from *stream*."""
        # Assigned before reading so ``self.bgeo`` is set even if read() raises.
        self.bgeo = parsed = BlendGeometry()
        parsed.read(stream)
class BlendGeometry(PackedResource):
    """Resource tagged ``'BGEO'``: blends with per-LOD vertex deltas.

    ``read`` decodes three sections located through ``StreamPtr`` entries in
    the header: the blend/LOD table, a table of packed 16-bit vertex
    pointers, and a table of 3-component vectors.
    """
    ID = 0x067CAA11
    TAG = 'BGEO'

    class VERSION:
        STANDARD = 0x00000300
        OTHER = 0x00030000
        S4 = 0x00000600

    def unpack(self, packed):
        """Decode a stored 16-bit value into a float.

        The top bit is flipped, the result is reinterpreted as a signed
        16-bit integer and divided by 2000.
        """
        return c_int16(packed ^ 0x8000).value / 2000.0

    def pack(self, unpacked):
        """Encode a float into the stored 16-bit form (inverse of ``unpack``).

        The value is multiplied by 2000 and floored before the top bit is
        flipped, so inputs that are not multiples of 1/2000 round down.
        """
        return c_int16(math.floor(unpacked * 2000.0) ^ 0x8000).value

    class BlendVertex:
        __slots__ = {
            'position',
            'normal'
        }

        def __init__(self):
            self.position = None
            self.normal = None

    class VertexPtr:
        """Packed vertex pointer: two flag bits plus an offset in the bits above."""
        FLAG_HAS_POSITION = 0x00000001
        FLAG_HAS_NORMAL = 0x00000002

        def __init__(self, val):
            self.value = val

        def get_offset(self):
            return self.value >> 2

        def set_offset(self, value):
            # Keep the two flag bits while replacing the offset.
            self.value = (value << 2) + (self.value & 0x00000003)

        offset = property(get_offset, set_offset)

        def get_has_position(self):
            return Flag.is_set(self.value, self.FLAG_HAS_POSITION)

        def set_has_position(self, value):
            if value:
                self.value = Flag.set(self.value, self.FLAG_HAS_POSITION)
            else:
                self.value = Flag.unset(self.value, self.FLAG_HAS_POSITION)

        has_position = property(get_has_position, set_has_position)

        def get_has_normal(self):
            return Flag.is_set(self.value, self.FLAG_HAS_NORMAL)

        def set_has_normal(self, value):
            if value:
                self.value = Flag.set(self.value, self.FLAG_HAS_NORMAL)
            else:
                self.value = Flag.unset(self.value, self.FLAG_HAS_NORMAL)

        has_normal = property(get_has_normal, set_has_normal)

    class LodPtr:
        """Header entry for one LOD of one blend."""

        def __init__(self, start_vertex_id, vertex_count, vector_count):
            self.start_vertex_id = start_vertex_id
            self.vertex_count = vertex_count
            self.vector_count = vector_count

        def __str__(self):
            return "0x%08X 0x%08X 0x%08X" % (self.start_vertex_id, self.vertex_count, self.vector_count)

    def __init__(self, key=None, stream=None):
        self.blends = []
        self.version = self.VERSION.STANDARD
        PackedResource.__init__(self, key, stream)

    def write(self, stream, resource=None):
        """Write the tag and version to *stream*.

        Only the 8-byte preamble is emitted; the blend, pointer and vector
        sections are not serialized yet.
        """
        s = StreamWriter(stream)
        s.chars(self.TAG)
        s.u32(self.version)
        # TODO: write this crazy thing

    def read(self, stream, resource=None):
        """Parse the resource from *stream*, replacing ``self.blends``.

        :raises AssertionError: if the tag or the two fixed header fields do
            not have their expected values.
        """
        s = StreamReader(stream)
        # Read first, check afterwards: a read placed inside an ``assert``
        # would be skipped under ``python -O`` and desynchronize the stream.
        tag = s.chars(4)
        assert tag == self.TAG
        self.version = s.u32()
        # Start clean so a second read does not append to earlier results.
        self.blends = []

        # Files at or above the S4 version carry no blend count.
        blend_count = s.i32() if self.version < self.VERSION.S4 else 1
        lod_count = s.i32()
        pointer_count = s.i32()
        vector_count = s.i32()
        fixed_a = s.i32()
        fixed_b = s.i32()
        assert fixed_a == 0x00000008
        assert fixed_b == 0x0000000C
        blend_section = StreamPtr.begin_read(s)
        vertex_section = StreamPtr.begin_read(s)
        vector_section = StreamPtr.begin_read(s)

        blend_section.end()
        lod_ptrs = []
        for blend_index in range(blend_count):
            blend = Blend()
            blend.age_gender_flags = s.u32()
            blend.blend_region = s.u32()
            self.blends.append(blend)
            blend.lods = [Blend.LOD() for lod_index in range(lod_count)]
            lod_ptrs.append([self.LodPtr(s.u32(), s.u32(), s.u32()) for lod_index in range(lod_count)])

        vertex_section.end()
        pointers = [self.VertexPtr(s.i16()) for pointer_index in range(pointer_count)]
        vector_section.end()
        vectors = [[self.unpack(s.i16()) for i in range(3)] for vector_index in range(vector_count)]

        for blend_index, blend in enumerate(self.blends):
            # NOTE(review): both cursors restart at 0 for every blend, as in
            # the original code - confirm against a file with several blends.
            pointer_base = 0
            vector_cursor = 0
            blend_lod_ptrs = lod_ptrs[blend_index]
            for lod_index, lod in enumerate(blend.lods):
                # Each blend has its own list of ``lod_count`` entries, so the
                # entry for this LOD is at ``lod_index``. The previous
                # ``blend_index + lod_index`` picked the wrong entry (or
                # skipped the LOD) for every blend after the first.
                lod_ptr = blend_lod_ptrs[lod_index]
                vertex_id = lod_ptr.start_vertex_id
                for pointer_index in range(lod_ptr.vertex_count):
                    vertex = Blend.Vertex()
                    pointer = pointers[pointer_base + pointer_index]
                    vector_cursor += pointer.offset
                    vertex.id = vertex_id
                    # Position, when present, comes first; the normal follows.
                    slot = vector_cursor
                    if pointer.has_position:
                        vertex.position = vectors[slot]
                        slot += 1
                    if pointer.has_normal:
                        vertex.normal = vectors[slot]
                    vertex_id += 1
                    lod.vertices.append(vertex)
                pointer_base += lod_ptr.vertex_count
                vector_cursor += lod_ptr.vector_count


class BodySkinController(SkinController, PackedResource):
    """Packed resource listing bones by name together with one pose each.

    Layout handled by ``read``/``write``: u32 version, u32 name count and
    the names, then u32 pose count and 12 floats per pose.
    """
    ID = 0x00AE6C67

    def __init__(self, key):
        PackedResource.__init__(self, key)
        SkinController.__init__(self)
        self.version = 0

    def read(self, stream, resource=None):
        """Parse names and poses from *stream* and rebuild ``self.bones``.

        Each pose is stored as 4 rows of 3 floats, in stream order.
        """
        s = StreamReader(stream)
        self.version = s.u32()
        name_count = s.u32()
        names = [s.s7(16, '>') for _ in range(name_count)]
        pose_count = s.u32()
        poses = [[[s.f32() for _ in range(3)] for _ in range(4)] for _ in range(pose_count)]
        self.bones = [self.Bone(names[index], pose) for index, pose in enumerate(poses)]

    def write(self, stream, resource=None):
        """Serialize the version, the bone names and the bone poses."""
        s = StreamWriter(stream)
        s.u32(self.version)
        s.u32(len(self.bones))
        for bone in self.bones:
            s.s7(bone.name, 16, '>')
        s.u32(len(self.bones))
        for bone in self.bones:
            # Emit 4 rows of 3 floats, the same shape and order read() parses.
            # The earlier 3 x 4 loop could not index a pose built by read().
            # NOTE(review): like the original, this assumes bone[row][col]
            # reaches the pose passed to Bone(...) - confirm against Bone.
            for row in range(4):
                for column in range(3):
                    s.f32(bone[row][column])
